import torch
from torch import nn
# from mmcls.models.backbones.rednet import RedNet
from mmclassification.mmcls.models.backbones.rednet import RedNet

class redNet(nn.Module):
    """Score a pair of images from the difference of their RedNet features.

    Both images are passed through the same RedNet-101 backbone (one module,
    so the weights are shared). The difference of the two feature maps is
    globally average-pooled and fed through a three-layer MLP head.

    Args:
        num_classes: Number of outputs produced by the final linear layer.
            Defaults to 1, which gives a single score per image pair.
    """

    # Channel count the head expects from the backbone's final feature map.
    # NOTE(review): assumes RedNet(101) emits 2048 channels from its last
    # stage. If it does not, fc1 raises a shape error.
    _FEAT_DIM = 2048

    def __init__(self, num_classes=1):
        super().__init__()
        self.model = RedNet(101)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc1 = nn.Linear(self._FEAT_DIM, 1024)
        self.fc2 = nn.Linear(1024, 512)
        # Previously hard-coded to 1 output, silently ignoring num_classes.
        self.fc3 = nn.Linear(512, num_classes)

    def _features(self, img):
        """Run the shared backbone on ``img`` and return its final feature map."""
        feat = self.model(img)
        # Depending on version / out_indices, mmcls backbones may return a
        # tuple of stage outputs rather than a single tensor. Use the last one.
        if isinstance(feat, (tuple, list)):
            feat = feat[-1]
        return feat

    def forward(self, data):
        """Compute head outputs for a pair of image batches.

        The original author labelled this path as pretraining code.

        Args:
            data: Two-element sequence ``(x, x_ref)`` of image batches fed to
                the same backbone. The author's debug notes show shapes like
                (32, 3, 224, 224) for each (assumption, not enforced here).

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        x, x_ref = data
        diff = self._features(x) - self._features(x_ref)
        # Global average pool, then flatten to (batch, channels) for the MLP.
        pooled = torch.flatten(self.avgpool(diff), 1)
        q = torch.nn.functional.relu(self.fc1(pooled))
        q = torch.nn.functional.relu(self.fc2(q))
        return self.fc3(q)




